feat: implement OAuth2 token caching for Vertex AI authentication #15
Conversation
app/core/async_cache.py
```python
    ttl_seconds=3600
)  # 1-hour TTL

# OAuth2 token caching (55-min TTL with 5-min safety buffer before 1-hour token expiry)
async_oauth_token_cache: "AsyncCache" = _AsyncBackend(ttl_seconds=3300)  # 55-min TTL
```
Let's be conservative about this and use a 40-min TTL here.
I think 55 minutes is more reasonable if we add the refresh token mechanism
I've reconsidered the TTL setup and now believe we should avoid a fixed TTL in favor of storing an explicit expiration timestamp with each token. That way, we can validate tokens by comparing the stored expires_at against the current time.
Using both a TTL and the token's own expiration introduces unnecessary complexity and inconsistencies. For example, if the Redis TTL is set to 40 minutes but the token itself is valid for 60 minutes, we'd lose the cache prematurely after 40 minutes, causing an unnecessary token refresh. Conversely, if the token expires at 60 minutes but the cache TTL is longer, we might end up using an invalid token that still appears "alive" in the cache.
By tracking expires_at explicitly, we avoid these edge cases and can handle token refresh more reliably through expiration checks and error-based retries.
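For illustration, a minimal validity check along those lines (the 5-minute safety buffer is an assumption carried over from the earlier TTL discussion, not something in this PR):

```python
import time

SAFETY_BUFFER_SECONDS = 300  # assumed 5-min buffer, mirroring the TTL discussion above

def is_token_valid(token_data: dict) -> bool:
    # A token counts as valid only while its stored expiry,
    # minus the safety buffer, is still in the future.
    expires_at = token_data.get("expires_at")
    return bool(expires_at) and time.time() < expires_at - SAFETY_BUFFER_SECONDS
```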
I think this is good. One thing I'm concerned about is cleaning up old tokens, which could cause memory leaks.
If we do this:

```python
async_oauth_token_cache: "AsyncCache" = _AsyncBackend(ttl_seconds=None)  # No TTL
```

some potential problems arise:
- Expired tokens accumulate forever → memory leak
- Cache grows indefinitely
- No automatic cleanup mechanism
I think we can do this, though. In `vertex_adapter.py`:

```python
token_data = {
    "access_token": credentials.token,
    "token_type": "Bearer",
    "expires_at": credentials.expiry.timestamp(),  # Unix timestamp
    "scope": "https://www.googleapis.com/auth/cloud-platform",
    "cached_at": time.time(),  # For debugging
    "provider": "vertex",  # Helpful for multi-provider systems
}
```
And in `async_cache.py`, using `expires_at` to clean up tokens:

```python
import hashlib
import logging
import os
import time
from typing import Any

logger = logging.getLogger(__name__)

async_oauth_token_cache: "AsyncCache" = _AsyncBackend(ttl_seconds=None)


async def get_cached_oauth_token_async(api_key: str) -> dict[str, Any] | None:
    if not api_key:
        return None
    cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}"
    cached_data = await async_oauth_token_cache.get(cache_key)
    if not cached_data:
        return None
    expires_at = cached_data.get("expires_at")
    if not expires_at:
        # Malformed entry without an expiry; drop it.
        await async_oauth_token_cache.delete(cache_key)
        return None
    current_time = time.time()
    if expires_at <= current_time:
        # Token has expired: evict it and piggyback a small cleanup pass.
        await async_oauth_token_cache.delete(cache_key)
        await _opportunistic_cleanup(current_time, max_items=2)
        return None
    return cached_data


async def _opportunistic_cleanup(current_time: float, max_items: int = 2) -> None:
    """Evict a few expired token entries per call to bound memory growth."""
    cleaned = 0
    if hasattr(async_oauth_token_cache, "cache"):
        # In-memory backend: iterate over a snapshot of the dict.
        for key, value in list(async_oauth_token_cache.cache.items()):
            if cleaned >= max_items:
                break
            if key.startswith("token:"):
                expires_at = value.get("expires_at")
                if expires_at and expires_at <= current_time:
                    await async_oauth_token_cache.delete(key)
                    cleaned += 1
    elif hasattr(async_oauth_token_cache, "client"):
        # Redis backend: scan keys matching the token prefix.
        try:
            pattern = f"{os.getenv('REDIS_PREFIX', 'forge')}:token:*"
            async for redis_key in async_oauth_token_cache.client.scan_iter(match=pattern, count=10):
                if cleaned >= max_items:
                    break
                key_str = redis_key.decode() if isinstance(redis_key, bytes) else redis_key
                internal_key = key_str.split(":", 1)[-1]
                cached_data = await async_oauth_token_cache.get(internal_key)
                if cached_data:
                    expires_at = cached_data.get("expires_at")
                    if expires_at and expires_at <= current_time:
                        await async_oauth_token_cache.delete(internal_key)
                        cleaned += 1
        except Exception:
            # Cleanup is best-effort; never let it break the caller.
            pass


async def cache_oauth_token_async(api_key: str, token_data: dict[str, Any]) -> None:
    if not api_key or not token_data:
        return
    if "expires_at" not in token_data:
        logger.warning("OAuth token cached without expires_at - skipping")
        return
    cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}"
    await async_oauth_token_cache.set(cache_key, token_data)
```
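To sketch how `vertex_adapter.py` might consume these helpers (a hedged sketch, not the PR's actual code: `_authenticate_with_service_account` is a hypothetical name, and the timezone handling assumes google-auth's convention of returning a naive UTC `expiry` datetime):

```python
import time
from datetime import timezone


async def get_vertex_access_token(api_key: str) -> str:
    # Serve from cache when a still-valid token exists.
    cached = await get_cached_oauth_token_async(api_key)
    if cached:
        return cached["access_token"]
    # Otherwise re-authenticate and cache the fresh token.
    credentials = _authenticate_with_service_account(api_key)  # hypothetical helper
    token_data = {
        "access_token": credentials.token,
        "token_type": "Bearer",
        # google-auth exposes expiry as a naive UTC datetime, so attach UTC
        # before converting to a Unix timestamp.
        "expires_at": credentials.expiry.replace(tzinfo=timezone.utc).timestamp(),
        "scope": "https://www.googleapis.com/auth/cloud-platform",
        "cached_at": time.time(),
        "provider": "vertex",
    }
    await cache_oauth_token_async(api_key, token_data)
    return credentials.token
```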
```diff
@@ -118,6 +118,8 @@ def stats(self) -> dict[str, Any]:
 # Expose the global cache instances
```
We're working on a PR to fully switch to async and will discard the sync cache very soon. We don't need to support this anymore.
```python
self.parse_api_key(api_key)

# check cache first for existing valid token
cached_token = await get_cached_oauth_token_async(api_key)
```
I actually thought about caching the OAuth2 token while implementing this. My biggest concern is that if, for some reason, the cached token is invalidated before its expiry time, there is no way for users to manually invalidate the cache and trigger new token generation. They would have to wait 40-50 minutes for the cache to expire.
And if we choose to run a simple connection test on the token before each use, performance would degrade and effectively fall back to the uncached version. Any suggestions?
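One option, sketched here as a suggestion rather than something in this PR: expose an explicit invalidation helper so callers can force re-authentication after an auth error instead of waiting for expiry:

```python
import hashlib


async def invalidate_cached_oauth_token_async(api_key: str) -> None:
    # Hypothetical helper: drop the cached token so the next call
    # falls through to fresh token generation.
    if not api_key:
        return
    cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}"
    await async_oauth_token_cache.delete(cache_key)
```

The adapter could then call this when Vertex rejects a token (e.g. a 401 response), which avoids the per-request connection test and its performance cost.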
I believe Vertex follows the standard OAuth 2.0 protocol, which returns two key tokens upon authorization:
- Access Token – Used to authenticate user access to resources (e.g., Vertex APIs). It typically has a short time-to-live (TTL), such as 1 hour.
- Refresh Token – Used solely to obtain a new access token once the original expires. It enables seamless long-term authentication without requiring the user to log in again.
The typical workflow looks like this:
[User Logs In / Grants Access] → [Authorization Server returns access_token + refresh_token] → access_token is used to call Vertex APIs → refresh_token is used to obtain a new access_token when needed.
So ideally the token logic would be:
```python
def get_vertex_token(api_key):
    token_data = redis.get(f"vertex_token:{api_key}")
    if not token_data:
        raise Unauthorized("Vertex not bound")
    # If the access token is not expired, use it directly.
    if token_data["expires_at"] > current_timestamp():
        return token_data["access_token"]
    # If there is a refresh_token, refresh the access token.
    if "refresh_token" in token_data:
        new_token = refresh_vertex_token(token_data["refresh_token"])
        redis.set(f"vertex_token:{api_key}", new_token, ttl=3600)
        return new_token["access_token"]
    # Cannot refresh; the user has to re-authenticate.
    raise TokenExpired("Please reconnect your Vertex account")
```
As a follow-up step in the next PR, we should store the refresh token, encrypted with Fernet (fernet.encrypt(refresh_token)), in the Postgres DB in a new table named something like user_token. (Please create a new issue as a record if you agree with my suggestion.)
Also, I suggest the access-token cache JSON should also include: 1. scope, 2. token_type ("Bearer", etc.).
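For reference, a minimal sketch of that Fernet idea using the cryptography package (the FERNET_KEY source and the user_token table are assumptions, not part of this PR):

```python
from cryptography.fernet import Fernet

fernet = Fernet(FERNET_KEY)  # FERNET_KEY would come from config/env (assumption)

# Encrypt before writing to the hypothetical user_token table...
encrypted = fernet.encrypt(refresh_token.encode())

# ...and decrypt after reading it back.
refresh_token = fernet.decrypt(encrypted).decode()
```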
I think for this case we can keep it straightforward and simpler: just always use the user's JSON file to re-authenticate whenever the current access token has expired or been cleaned up. Let's just forget about refresh tokens at this stage.
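For context, a hedged sketch of that re-authentication path with google-auth, assuming the user's API key is the service-account JSON itself:

```python
import json

from google.auth.transport.requests import Request
from google.oauth2 import service_account


def reauthenticate_from_json(service_account_json: str) -> service_account.Credentials:
    # Rebuild credentials from the user's service-account JSON and
    # mint a fresh access token (network call to Google's token endpoint).
    info = json.loads(service_account_json)
    credentials = service_account.Credentials.from_service_account_info(
        info, scopes=["https://www.googleapis.com/auth/cloud-platform"]
    )
    credentials.refresh(Request())
    return credentials  # credentials.token / credentials.expiry are now populated
```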
| "token": credentials.token, | ||
| "expiry": credentials.expiry.isoformat() | ||
| } | ||
| await cache_oauth_token_async(api_key, token_data) |
Unlike other providers, the API key for OAuth2 here is a long JSON string, which is not suited for a cache key. Let's use a hash function to generate the cache key.
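That is, something like the SHA-256 keying already used in the `async_cache.py` proposal above:

```python
import hashlib

cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}"
```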
Thanks so much for this optimization, very nice work overall.

I think this PR should be rebased on the database and async cache refactor from #16.
Force-pushed from 04d6d4d to 0d20eaa.
- Add OAuth2 token caching functions to async_cache.py and cache.py
- Create async_oauth_token_cache with 55-minute TTL (5-min safety buffer)
- Update vertex_authentication method to check cache first before token refresh
- Add OAuth2 token cache statistics tracking
- Cache key format: token:{api_key}
- Performance optimization: reduce unnecessary token refresh calls
- Replace fixed TTL with token's native expires_at timestamp validation
- Use SHA-256 hashed cache keys for long JSON service account credentials
- Add opportunistic cleanup to prevent expired-token memory leaks
- Standardize token structure with access_token, expires_at, token_type fields
- Simplify Vertex AI authentication using service account re-authentication
- Move imports to top level for better code quality

Addresses PR feedback on cache key security, TTL complexity, and memory management.
I am going to close this PR and make a PR to the main branch, as there seem to be conflicts with the vertex_support branch.
Let me know if this implementation is sufficient. All the test cases that I ran worked.